{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Interoperation with other JAX frameworks\n",
    "\n",
    "[![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/brainpy/brainpy/blob/master/docs_version2/tutorial_advanced/interoperation.ipynb)\n",
    "[![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/brainpy/brainpy/blob/master/docs_version2/tutorial_advanced/interoperation.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "BrainPy is designed to interoperate easily with other JAX frameworks."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "pycharm": {
     "is_executing": true
    },
    "ExecuteTime": {
     "end_time": "2025-10-06T03:33:37.549326Z",
     "start_time": "2025-10-06T03:33:33.404118Z"
    }
   },
   "source": [
    "import brainpy as bp\n",
    "import brainpy_datasets as bd"
   ],
   "outputs": [],
   "execution_count": 1
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-06T03:33:37.556584Z",
     "start_time": "2025-10-06T03:33:37.552476Z"
    }
   },
   "source": [
    "import brainpy.math as bm\n",
    "import jax.numpy as jnp\n",
    "import numpy as np"
   ],
   "outputs": [],
   "execution_count": 2
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Data are exchangeable among different frameworks\n",
    "This works because a BrainPy ``Array`` can be directly converted to a JAX array or a NumPy ndarray."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Convert an ``Array`` into a JAX array."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-06T03:33:37.800157Z",
     "start_time": "2025-10-06T03:33:37.562302Z"
    }
   },
   "source": [
    "b = bm.random.randint(10, size=5)"
   ],
   "outputs": [],
   "execution_count": 3
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-06T03:33:38.703173Z",
     "start_time": "2025-10-06T03:33:37.806877Z"
    }
   },
   "source": [
    "# bm.as_jax() returns a JAX array; the input may be a BrainPy Array or already a JAX array\n",
    "bm.as_jax(b)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Convert an ``Array`` into a NumPy ndarray."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([9, 9, 0, 4, 7])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# np.asarray() converts an Array into a NumPy ndarray\n",
    "np.asarray(b)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Convert a NumPy ndarray into an ``Array``."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array([0, 1, 2, 3, 4], dtype=int32)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "bm.asarray(np.arange(5))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Convert a JAX array into an ``Array``."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array([0, 1, 2, 3, 4], dtype=int32)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "bm.asarray(jnp.arange(5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array([0, 1, 2, 3, 4], dtype=int32)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "bm.Array(jnp.arange(5))"
   ]
  },
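  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The same kind of round trip works between plain NumPy and JAX arrays, which is what makes data exchange with other JAX frameworks straightforward. The cell below is a minimal sketch that uses only `numpy` and `jax.numpy`; the variable names are illustrative."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax.numpy as jnp\n",
    "import numpy as np\n",
    "\n",
    "x_np = np.arange(5)         # NumPy ndarray\n",
    "x_jax = jnp.asarray(x_np)   # NumPy -> JAX array\n",
    "x_back = np.asarray(x_jax)  # JAX array -> NumPy ndarray\n",
    "\n",
    "# The values survive the round trip unchanged\n",
    "assert isinstance(x_back, np.ndarray)\n",
    "assert np.array_equal(x_np, x_back)"
   ]
  },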
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Other JAX frameworks can be integrated into a BrainPy program"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this example, we use [Flax](https://github.com/google/flax), a library for deep neural networks, to define a convolutional neural network (CNN). Then, we integrate this CNN into an RNN model defined with BrainPy's syntax."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, we use **Flax** to define the CNN."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "from flax import linen as nn\n",
    "\n",
    "class CNN(nn.Module):\n",
    "  \"\"\"A CNN model implemented by using Flax.\"\"\"\n",
    "\n",
    "  @nn.compact\n",
    "  def __call__(self, x):\n",
    "    x = nn.Conv(features=32, kernel_size=(3, 3))(x)\n",
    "    x = nn.relu(x)\n",
    "    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))\n",
    "    x = nn.Conv(features=64, kernel_size=(3, 3))(x)\n",
    "    x = nn.relu(x)\n",
    "    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))\n",
    "    x = x.reshape((x.shape[0], -1))  # flatten\n",
    "    x = nn.Dense(features=256)(x)\n",
    "    x = nn.relu(x)\n",
    "    return x"
   ]
  },
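  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick check, the Flax module can be initialized and applied on its own before it is combined with BrainPy. This sketch assumes the `CNN` class from the cell above and a `(1, 4, 28, 1)` sample input; `key`, `sample`, `variables`, and `features` are illustrative names."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "key = jax.random.PRNGKey(0)\n",
    "sample = jnp.ones([1, 4, 28, 1])           # (batch, height, width, channels)\n",
    "variables = CNN().init(key, sample)        # create the parameters from a sample input\n",
    "features = CNN().apply(variables, sample)  # forward pass with those parameters\n",
    "\n",
    "# The final Dense layer has 256 features, so one sample yields a (1, 256) output\n",
    "assert features.shape == (1, 256)"
   ]
  },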
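  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then, we define an RNN model with the BrainPy interface. The Flax CNN is wrapped by ``bp.interop.FromFlax`` so that it can be called like any other BrainPy layer."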
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Network(bp.DynamicalSystemNS):\n",
    "  \"\"\"A network model implemented by BrainPy.\"\"\"\n",
    "\n",
    "  def __init__(self):\n",
    "    super(Network, self).__init__()\n",
    "    # Wrap the Flax CNN; the sample input has the shape of one time step\n",
    "    self.cnn = bp.interop.FromFlax(CNN(), bm.ones([1, 4, 28, 1]))\n",
    "    # The CNN outputs 256 features, which feed the GRU cell\n",
    "    self.rnn = bp.layers.GRUCell(256, 100)\n",
    "    self.linear = bp.layers.Dense(100, 10)\n",
    "\n",
    "  def update(self, x):\n",
    "    x = self.cnn(x)\n",
    "    x = self.rnn(x)\n",
    "    x = self.linear(x)\n",
    "    return x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We initialize the network, the optimizer, the loss function, and the BPTT trainer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "net = Network()\n",
    "opt = bp.optim.Momentum(0.1)\n",
    "\n",
    "def loss_func(predictions, targets):\n",
    "  # Take the maximum over axis 1 (the time axis, assuming batch-major outputs)\n",
    "  logits = bm.max(predictions, axis=1)\n",
    "  loss = bp.losses.cross_entropy_loss(logits, targets)\n",
    "  accuracy = bm.mean(bm.argmax(logits, -1) == targets)\n",
    "  return loss, {'accuracy': accuracy}\n",
    "\n",
    "trainer = bp.BPTT(net, loss_fun=loss_func, optimizer=opt, loss_has_aux=True)"
   ]
  },
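  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We load the MNIST dataset. Each 28×28 image is reshaped into a sequence of 7 time steps, each a 4×28 slice with one channel, so ``X`` has shape ``(num_samples, 7, 4, 28, 1)``. Pixel values are scaled to the range [0, 1]."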
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataset = bd.vision.MNIST(r'D:\\data', download=True)\n",
    "X = train_dataset.data.reshape((-1, 7, 4, 28, 1)) / 255\n",
    "Y = train_dataset.targets"
   ]
  },
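  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see what this reshape does, the sketch below applies it to a small synthetic array instead of MNIST (the names `images` and `seq` are illustrative). Each 28×28 image becomes 7 time steps, and time step `t` holds image rows `4*t` to `4*t + 3`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "images = np.arange(2 * 28 * 28).reshape((2, 28, 28))  # two synthetic 28x28 images\n",
    "seq = images.reshape((-1, 7, 4, 28, 1))               # (batch, time, height, width, channel)\n",
    "\n",
    "assert seq.shape == (2, 7, 4, 28, 1)\n",
    "# Time step 2 of the first image contains rows 8..11 of that image\n",
    "assert np.array_equal(seq[0, 2, :, :, 0], images[0, 8:12, :])"
   ]
  },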
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, we train the model with the ``BPTT.fit()`` function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train 100 steps, use 32.5824 s, train loss 0.96465, accuracy 0.66015625\n",
      "Train 200 steps, use 30.9035 s, train loss 0.38974, accuracy 0.89453125\n",
      "Train 300 steps, use 33.1075 s, train loss 0.31525, accuracy 0.890625\n",
      "Train 400 steps, use 31.4062 s, train loss 0.23846, accuracy 0.91015625\n",
      "Train 500 steps, use 32.3371 s, train loss 0.21995, accuracy 0.9296875\n",
      "Train 600 steps, use 32.5692 s, train loss 0.20885, accuracy 0.92578125\n",
      "Train 700 steps, use 33.0139 s, train loss 0.24748, accuracy 0.90625\n",
      "Train 800 steps, use 31.9635 s, train loss 0.14563, accuracy 0.953125\n",
      "Train 900 steps, use 31.8845 s, train loss 0.17017, accuracy 0.94140625\n",
      "Train 1000 steps, use 32.0537 s, train loss 0.09413, accuracy 0.95703125\n",
      "Train 1100 steps, use 32.3714 s, train loss 0.06015, accuracy 0.984375\n",
      "Train 1200 steps, use 31.6957 s, train loss 0.12061, accuracy 0.94921875\n",
      "Train 1300 steps, use 31.8346 s, train loss 0.13908, accuracy 0.953125\n",
      "Train 1400 steps, use 31.5252 s, train loss 0.10718, accuracy 0.953125\n",
      "Train 1500 steps, use 31.7274 s, train loss 0.07869, accuracy 0.96875\n",
      "Train 1600 steps, use 32.3928 s, train loss 0.08295, accuracy 0.96875\n",
      "Train 1700 steps, use 31.7718 s, train loss 0.07569, accuracy 0.96484375\n",
      "Train 1800 steps, use 31.9243 s, train loss 0.08607, accuracy 0.9609375\n",
      "Train 1900 steps, use 32.2454 s, train loss 0.04332, accuracy 0.984375\n",
      "Train 2000 steps, use 31.6231 s, train loss 0.02369, accuracy 0.9921875\n",
      "Train 2100 steps, use 31.7800 s, train loss 0.03862, accuracy 0.9765625\n",
      "Train 2200 steps, use 31.5431 s, train loss 0.01871, accuracy 0.9921875\n",
      "Train 2300 steps, use 32.1064 s, train loss 0.03255, accuracy 0.9921875\n"
     ]
    }
   ],
   "source": [
    "trainer.fit([X, Y], batch_size=256, num_epoch=10)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
